807736
@@ -25,6 +25,9 @@
import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.BiFunction;
+import java.util.function.Consumer;
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.FileStatus;
@@ -53,6 +56,7 @@
import org.apache.hadoop.security.token.Token;
 import org.apache.yetus.audience.InterfaceAudience;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+import org.apache.hbase.thirdparty.com.google.common.annotations.VisibleForTesting;
 import org.apache.hadoop.hbase.shaded.protobuf.generated.ClientProtos;
 import org.apache.hadoop.hbase.shaded.protobuf.generated.ClientProtos.BulkLoadHFileRequest;
 import org.apache.hadoop.hbase.shaded.protobuf.generated.ClientProtos.CleanupBulkLoadRequest;
@@ -106,6 +110,7 @@
public class SecureBulkLoadManager {
   private Path baseStagingDir;
 
   private UserProvider userProvider;
+  private ConcurrentHashMap<UserGroupInformation, Integer> ugiReferenceCounter;
   private Connection conn;
 
   SecureBulkLoadManager(Configuration conf, Connection conn) {
@@ -116,6 +121,7 @@
public class SecureBulkLoadManager {
   public void start() throws IOException {
     random = new SecureRandom();
     userProvider = UserProvider.instantiate(conf);
+    ugiReferenceCounter = new ConcurrentHashMap<>();
     fs = FileSystem.get(conf);
     baseStagingDir = new Path(FSUtils.getRootDir(conf), HConstants.BULKLOAD_STAGING_DIR_NAME);
 
@@ -158,7 +164,7 @@
public class SecureBulkLoadManager {
     } finally {
       UserGroupInformation ugi = getActiveUser().getUGI();
       try {
-        if (!UserGroupInformation.getLoginUser().equals(ugi)) {
+        if (!UserGroupInformation.getLoginUser().equals(ugi) && !isUserReferenced(ugi)) {
           FileSystem.closeAllForUGI(ugi);
         }
       } catch (IOException e) {
@@ -167,6 +173,38 @@
public class SecureBulkLoadManager {
     }
   }
 
+  /** Test-only hook; volatile because it may be set and read on different threads. */
+  private volatile Consumer<HRegion> fsCreatedListener;
+
+  @VisibleForTesting
+  void setFsCreatedListener(Consumer<HRegion> fsCreatedListener) {
+    this.fsCreatedListener = fsCreatedListener;
+  }
+
+  /**
+   * Records one more in-flight bulk load for {@code ugi}. While the count is positive,
+   * {@link #isUserReferenced} is true and {@code FileSystem.closeAllForUGI} is skipped.
+   */
+  private void incrementUgiReference(UserGroupInformation ugi) {
+    // ConcurrentHashMap.merge is atomic: inserts 1 when absent, otherwise adds 1.
+    ugiReferenceCounter.merge(ugi, 1, Integer::sum);
+  }
+
+  /**
+   * Releases one reference taken by {@link #incrementUgiReference}; the mapping is removed
+   * when the last reference is released, so the map does not retain idle users.
+   */
+  private void decrementUgiReference(UserGroupInformation ugi) {
+    // computeIfPresent is atomic; returning null removes the entry. No-op if ugi is absent.
+    ugiReferenceCounter.computeIfPresent(ugi,
+        (key, count) -> count > 1 ? count - 1 : null);
+  }
+
+  /** Returns whether at least one bulk load currently holds a reference for {@code ugi}. */
+  private boolean isUserReferenced(UserGroupInformation ugi) {
+    return ugiReferenceCounter.getOrDefault(ugi, 0) > 0;
+  }
+
   public Map<byte[], List<Path>> secureBulkLoadHFiles(final HRegion region,
       final BulkLoadHFileRequest request) throws IOException {
     final List<Pair<byte[], String>> familyPaths = new ArrayList<>(request.getFamilyPathCount());
@@ -208,6 +246,7 @@
public class SecureBulkLoadManager {
     Map<byte[], List<Path>> map = null;
 
     try {
+      incrementUgiReference(ugi);
       // Get the target fs (HBase region server fs) delegation token
       // Since we have checked the permission via 'preBulkLoadHFile', now let's give
       // the 'request user' necessary token to operate on the target fs.
@@ -237,6 +276,9 @@
public class SecureBulkLoadManager {
                 fs.setPermission(stageFamily, PERM_ALL_ACCESS);
               }
             }
+            if (fsCreatedListener != null) {
+              fsCreatedListener.accept(region);
+            }
             //We call bulkLoadHFiles as requesting user
             //To enable access prior to staging
             return region.bulkLoadHFiles(familyPaths, true,
@@ -248,6 +290,7 @@
public class SecureBulkLoadManager {
         }
       });
     } finally {
+      decrementUgiReference(ugi);
       if (region.getCoprocessorHost() != null) {
         region.getCoprocessorHost().postBulkLoadHFile(familyPaths, map);
       }
